import pandas as pd
import torch
import torchvision.transforms.functional as transform
import torchvision.transforms.functional as F
from EnsembleXAI import Ensemble, Metrics
from torchvision.transforms import Resize, CenterCrop
import os
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
import urllib.request
import json
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from captum.attr import IntegratedGradients, Occlusion, NoiseTunnel, visualization as viz, Saliency
import matplotlib.pyplot as plt
# Download the ImageNet class index. Entries are looked up later by the class
# index rendered as a string, and element [1] of an entry is the readable
# class name.
IMAGENET_CLASS_INDEX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with urllib.request.urlopen(IMAGENET_CLASS_INDEX_URL) as url:
    imagenet_classes_dict = json.load(url)
def download_class_images(class_id, masks_path):
    """Download from Kaggle the training image matching every mask of a class.

    For each file found in ``masks_path/class_id`` the file
    ``<stem>.JPEG`` is requested from the
    ``imagenet-object-localization-challenge`` competition and written to
    ``./images/<class_id>/``.

    Args:
        class_id: Name of the class sub-directory, e.g. ``"n01491361"``.
        masks_path: Directory that holds one sub-directory of masks per class.
    """
    # Local import: keeps this block independent of the file-level imports.
    import subprocess

    class_masks_dir = os.path.join(masks_path, class_id)
    kaggle_path = f"/ILSVRC/Data/CLS-LOC/train/{class_id}/"
    for file_name in sorted(os.listdir(class_masks_dir)):
        # splitext copes with any extension length, unlike slicing off the
        # last three characters.
        stem, _extension = os.path.splitext(file_name)
        file_name_jpeg = stem + ".JPEG"
        # An argument list (no shell) replaces the notebook-only `!kaggle`
        # line, so file names are never interpreted by a shell.
        # check=False: a failed download does not abort the loop, as before.
        subprocess.run(
            [
                "kaggle", "competitions", "download",
                "-f", f"{kaggle_path}{file_name_jpeg}",
                "-p", f"./images/{class_id}/",
                "-c", "imagenet-object-localization-challenge",
            ],
            check=False,
        )
def download():
    """Walk the mask classes and report which ones already have their images.

    A class counts as complete when its folder under ``images_dir`` exists and
    contains exactly 10 entries. For every other class a "Downloaded" line is
    printed; at the end the number of complete classes is printed.
    """
    complete = []
    for class_id in os.listdir(masks_dir):
        class_images = os.path.join(images_dir, class_id)
        is_full = os.path.exists(class_images) and len(os.listdir(class_images)) == 10
        if is_full:
            complete.append(class_id)
        else:
            # NOTE: the actual call is disabled in this version:
            # download_class_images(class_id, masks_dir)
            print(f"Downloaded {class_id}")
    print("Full dirs: " + str(len(complete)))
def images_list(image_path, resize=True):
    """Open every file of a directory as a PIL image.

    Files are read in sorted file-name order. ``os.listdir`` gives no ordering
    guarantee, and callers pair the images of one directory with the masks of
    another purely by position, so a deterministic order is required.

    Args:
        image_path: Directory containing the image files (with or without a
            trailing separator).
        resize: When true, each image is resized to 232x232 and then
            centre-cropped to 224x224; when false the opened image is
            returned untouched.

    Returns:
        A list with one PIL image per directory entry.
    """
    _resize = Resize([232, 232])
    _crop = CenterCrop(224)
    images = []
    for image_name in sorted(os.listdir(image_path)):
        image = Image.open(os.path.join(image_path, image_name))
        if resize:
            image = _crop(_resize(image))
        images.append(image)
    return images
def dict_to_matrix(original_data, explanations_dict, predictor, masks_tensor):
    """Collect EnsembleXAI metrics into a table with one row per explanation.

    For each entry of ``explanations_dict`` the impact and accordance metrics
    are evaluated at the thresholds 0.0, 0.1, ..., 0.9 (column names carry the
    threshold step 0..9), followed by one F1 score and one IOU value.

    Args:
        original_data: Input batch forwarded to the impact metrics.
        explanations_dict: Mapping from method name to its attributions.
        predictor: Callable forwarded to the impact metrics.
        masks_tensor: Reference masks forwarded to the accordance metrics,
            ``F1_score`` and ``intersection_over_union``.

    Returns:
        ``pandas.DataFrame`` indexed by the keys of ``explanations_dict``.
    """
    df = pd.DataFrame()
    for name, attributions in explanations_dict.items():
        for step in range(10):
            level = step / 10
            # The trailing 0 is passed through positionally; presumably the
            # replacement value for masked pixels (confirm in EnsembleXAI).
            df.loc[name, f"Decision Impact Ratio{step}"] = Metrics.decision_impact_ratio(
                original_data, predictor, attributions, level, 0
            )
            df.loc[name, f"Confidence Impact Ratio Same{step}"] = Metrics.confidence_impact_ratio(
                original_data, predictor, attributions, level, 0, compare_to="same_prediction"
            )
            df.loc[name, f"CIR Max{step}"] = Metrics.confidence_impact_ratio(
                original_data, predictor, attributions, level, 0, compare_to="new_prediction"
            )
            recall = Metrics.accordance_recall(attributions, masks_tensor, level)
            df.loc[name, f"Average Recall{step}"] = torch.mean(recall).item()
            precision = Metrics.accordance_precision(attributions, masks_tensor, level)
            df.loc[name, f"Average Precision{step}"] = torch.mean(precision).item()
        df.loc[name, "F1_score"] = Metrics.F1_score(attributions, masks_tensor)
        df.loc[name, "IOU"] = Metrics.intersection_over_union(attributions, masks_tensor)
    return df
# Data locations, built as Windows-style paths: `input` sits two levels above
# the working directory, the images live in ./images/.
_cwd_parts = os.getcwd().split("\\")
input_dir = "\\".join(_cwd_parts[:-2] + ["input"])
masks_dir = input_dir + "\\ImageNetS50\\train-semi-segmentation\\"
images_dir = os.getcwd() + "\\images\\"
print(os.listdir(images_dir))
['expl_n01491361.pickle', 'n01443537', 'n01491361', 'n01491361.png', 'n01531178', 'n01644373', 'n02104029', 'n02119022', 'n02123597', 'n02133161', 'n02165456', 'n02281406', 'n02325366', 'n02342885', 'n02396427', 'n02483362', 'n02504458', 'n02510455', 'n02690373', 'n02747177', 'n02783161', 'n02814533', 'n02859443', 'n02917067', 'n02992529', 'n03014705', 'n03047690', 'n03095699', 'n03197337', 'n03201208', 'n03445777', 'n03452741', 'n03584829', 'n03630383', 'n03775546', 'n03791053', 'n03874599', 'n03891251', 'n04026417', 'n04335435', 'n04380533', 'n04404412', 'n04447861', 'n04507155', 'n04522168', 'n04557648', 'n04562935', 'n04612504', 'n06794110', 'n07749582', 'n07831146', 'n12998815']
# Name of the class directory analysed by the rest of the script; it is also
# embedded in the saved figure and pickle file names.
# NOTE(review): this shadows the builtin `id`; renaming requires updating every later use.
id = "n01491361"
def load_all(classid):
    """Load the images and masks of one class in every form used later.

    Args:
        classid: Name of the class sub-directory under ``images_dir`` and
            ``masks_dir``.

    Returns:
        A 6-tuple ``(resized images, original images, image tensors,
        mask tensors, stacked image tensor, stacked mask tensor)``. Masks are
        binarised (``> 0``); the stacked mask keeps channel 0 only, repeated
        to the channel count of the stacked images.
    """
    class_images_path = images_dir + classid + "\\"
    class_masks_path = masks_dir + classid + "\\"

    all_img = images_list(class_images_path)
    all_img_org = images_list(class_images_path, resize=False)
    all_tens = [F.to_tensor(picture) for picture in all_img]

    all_msks = []
    for mask_image in images_list(class_masks_path):
        all_msks.append((F.to_tensor(mask_image) > 0).float())

    tens_img = torch.stack(all_tens)
    channel_count = tens_img.shape[1]
    first_channel = torch.stack(all_msks)[:, 0]
    tens_msks = first_channel.unsqueeze(dim=1).repeat(1, channel_count, 1, 1)
    return all_img, all_img_org, all_tens, all_msks, tens_img, tens_msks
(all_images, all_images_original, all_tensors,
 all_masks, tensor_images, tensor_masks) = load_all(id)
# Preview: each tile is an image tensor with its mask concatenated along
# dim=2; the tiles are then concatenated along dim=1 into one picture.
photos = [
    torch.cat([image_tensor, mask_tensor], dim=2)
    for image_tensor, mask_tensor in zip(all_tensors, all_masks)
]
display(transform.to_pil_image(torch.cat(photos, dim=1)))
# Pretrained ResNet-50 in inference mode, plus the preprocessing that ships
# with the same weights.
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()
resnet_transform = ResNet50_Weights.DEFAULT.transforms()


def pipeline(images):
    """Apply ``resnet_transform`` to each image and stack the results into one batch."""
    return torch.stack([resnet_transform(image) for image in images])


proper_data = pipeline(all_images_original)
# The predictions are only read as values, so skip building the autograd graph
# for this forward pass.
with torch.no_grad():
    outputs2 = model(proper_data)
_, preds2 = torch.max(outputs2, 1)
probs2 = torch.nn.functional.softmax(outputs2, dim=1)
[imagenet_classes_dict[str(i.item())][1] for i in preds2]  # gar = Lepisosteiformes (the gar order)
['tiger_shark', 'tiger_shark', 'tiger_shark', 'great_white_shark', 'tiger_shark', 'tiger_shark', 'tiger_shark', 'tiger_shark', 'hammerhead', 'gar']
# Integrated Gradients for a single sample (index 2), kept as a batch of one.
single_pred = preds2[2].unsqueeze(dim=0)
single_data = proper_data[2].unsqueeze(dim=0)
integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(
    single_data,
    target=single_pred,
    n_steps=200,
)
transformed_img = resnet_transform(all_images_original[2])

# White-to-black colormap shared by the positive-only heat maps.
default_cmap = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
    N=256,
)

# Captum's visualisation takes channel-last arrays, hence the transposes.
_ig_attr_hwc = np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
_ig_image_hwc = np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0))
_ = viz.visualize_image_attr(
    _ig_attr_hwc,
    _ig_image_hwc,
    method='heat_map',
    cmap=default_cmap,
    show_colorbar=True,
    sign='positive',
    outlier_perc=1,
)
display(all_images[2])
# Run a full garbage-collection pass before the next attribution runs
# (presumably to release memory held by earlier intermediate results).
import gc
gc.collect()
190
# SmoothGrad-squared noise tunnel wrapped around Integrated Gradients.
noise_tunnel = NoiseTunnel(integrated_gradients)
attributions_ig_nt = noise_tunnel.attribute(single_data, nt_samples=5, nt_type='smoothgrad_sq', target=single_pred)
# One attribution per image, each computed as a batch of one and concatenated.
# The loop follows the data instead of a hard-coded count of 10.
# NOTE(review): these inputs come from `tensor_images` (plain to_tensor
# output), whereas `single_data` and `preds2` went through
# `resnet_transform`; confirm the different preprocessing is intended.
attributions_ig_nt_all = torch.cat(
    [
        noise_tunnel.attribute(
            image_tensor.unsqueeze(dim=0),
            nt_samples=5,
            nt_type='smoothgrad_sq',
            target=image_pred.unsqueeze(dim=0),
        )
        for image_tensor, image_pred in zip(tensor_images, preds2)
    ],
    dim=0,
)
_ = viz.visualize_image_attr_multiple(
    np.transpose(attributions_ig_nt.squeeze().numpy(), (1, 2, 0)),
    np.array(all_images[2]),
    ["heat_map", "original_image"],
    ["positive", "all"],
    cmap=default_cmap,
    show_colorbar=True,
)
# Occlusion attributions with two window sizes: 15x15 (stride 8) and
# 25x25 (stride 20), each covering all 3 channels, with baseline 0.
_window_15 = dict(strides=(3, 8, 8), sliding_window_shapes=(3, 15, 15), baselines=0)
_window_25 = dict(strides=(3, 20, 20), sliding_window_shapes=(3, 25, 25), baselines=0)


def _to_hwc(attribution):
    """Squeeze an attribution tensor and return it as a channel-last NumPy array."""
    return np.transpose(attribution.squeeze().cpu().detach().numpy(), (1, 2, 0))


# Single sample, 15x15 window.
occlusion = Occlusion(model)
attributions_occ = occlusion.attribute(single_data, target=single_pred, **_window_15)
_ = viz.visualize_image_attr_multiple(
    _to_hwc(attributions_occ),
    np.array(all_images[2]),
    ["heat_map", "original_image"],
    ["positive", "all"],
    show_colorbar=True,
    outlier_perc=2,
)

# Single sample, 25x25 window.
occlusion = Occlusion(model)
attributions_occ2 = occlusion.attribute(single_data, target=single_pred, **_window_25)
_2 = viz.visualize_image_attr_multiple(
    _to_hwc(attributions_occ2),
    np.array(all_images[2]),
    ["heat_map", "original_image"],
    ["all", "positive"],
    show_colorbar=True,
    outlier_perc=2,
)

# Whole batch, both window sizes.
occlusion = Occlusion(model)
attributions_occ_all_25 = occlusion.attribute(tensor_images, target=preds2, **_window_25)
attributions_occ_all_15 = occlusion.attribute(tensor_images, target=preds2, **_window_15)
# Saliency attributions for the whole batch; sample 2 is visualised.
saliency = Saliency(model)
attr_saliency = saliency.attribute(tensor_images, target=preds2)
_saliency_hwc = np.transpose(attr_saliency[2].numpy(), (1, 2, 0))
_3 = viz.visualize_image_attr_multiple(
    _saliency_hwc,
    np.array(all_images[2]),
    ["heat_map", "original_image"],
    ["positive", "positive"],
    show_colorbar=True,
    outlier_perc=2,
)
def sample_xai(images):
    """Explain a batch with coarse occlusion, always targeting ``single_pred``.

    Args:
        images: Batch of inputs; its first dimension is the batch size.

    Returns:
        Occlusion attributions (50x50 window, stride 40, baseline 0) for
        ``images``, with ``single_pred`` repeated once per batch element.
    """
    batch_size = images.shape[0]
    target = single_pred if batch_size == 1 else single_pred.repeat(batch_size)
    return occlusion.attribute(
        images,
        strides=(3, 40, 40),
        target=target,
        sliding_window_shapes=(3, 50, 50),
        baselines=0,
    )
# Stability of `sample_xai` for the single sample. The third argument is
# 10 exact copies of that same sample (`repeat`), not perturbed variants.
Metrics.stability(sample_xai, single_data.squeeze(dim=0), single_data.repeat(10,1,1,1))
0.0
# Ensemble of the two single-sample explanations (occlusion + noise tunnel),
# aggregated with avg / min / max.
x = torch.cat([attributions_occ, attributions_ig_nt])
aggregated1, aggregated2, aggregated3 = (
    Ensemble.basic(x, aggregating_func=func_name)
    for func_name in ('avg', 'min', 'max')
)
_ = viz.visualize_image_attr_multiple(
    np.transpose(aggregated1.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    np.array(all_images[2]),
    ["heat_map", "original_image", "masked_image"],
    ["all", "positive", "positive"],
    show_colorbar=True,
    outlier_perc=2,
)

# Batch-wide ensemble: the two occlusion variants are stacked along a new
# dim=1 and aggregated the same three ways.
all_stacked = torch.stack([attributions_occ_all_15, attributions_occ_all_25], dim=1)
attr_agg_avg, attr_agg_min, attr_agg_max = (
    Ensemble.basic(all_stacked, aggregating_func=func_name)
    for func_name in ('avg', 'min', 'max')
)
plt.style.use('fast')
def plot_explanations(images, explanations_dict, classes_predicted,
                      method="heat_map"):
    """Plot each image next to every method's explanation of that image.

    The grid has one row per image. Column 0 shows the image with its
    predicted class as the y-label; each further column shows one entry of
    ``explanations_dict``. The figure is saved to ``images/<id>.png`` (using
    the module-level ``id``) and then shown.

    Args:
        images: Sequence of images; ``images[i]`` is drawn in row ``i``.
        explanations_dict: Mapping from method name to attributions, indexed
            by image position along the first dimension.
        classes_predicted: Row labels, one per image.
        method: Visualisation method forwarded to
            ``viz.visualize_image_attr``.
    """
    nrow, ncol = len(images), len(explanations_dict) + 1
    # squeeze=False keeps `ax` two-dimensional even when there is one image,
    # so the `ax[i, 0]` indexing below always works.
    fig, ax = plt.subplots(nrows=nrow, ncols=ncol, figsize=(14, 3 * nrow), squeeze=False)
    columns_names = ["Original"] + list(explanations_dict.keys())
    for col, col_name in zip(ax[0], columns_names):
        col.title.set_text(col_name)
    for i, img in enumerate(images):
        image_axis = ax[i, 0]
        image_axis.xaxis.set_ticks_position("none")
        image_axis.yaxis.set_ticks_position("none")
        image_axis.set_yticklabels([])
        image_axis.set_xticklabels([])
        image_axis.imshow(np.array(img), vmin=0, vmax=255)
        image_axis.set_ylabel(classes_predicted[i], size='large')
        for col, explanations in zip(ax[i, 1:], explanations_dict.values()):
            # Row i shows the explanation of image i. Indexing by the column
            # counter instead made every row display the same attributions.
            expl = explanations[i]
            # Non-negative attributions get the positive-only colormap.
            sign = "all"
            cmap = None
            if expl.amin() >= 0:
                sign = "positive"
                cmap = default_cmap
            _ = viz.visualize_image_attr(
                np.transpose(expl.squeeze().detach().cpu().numpy(), (1, 2, 0)),
                original_image=np.array(img),
                method=method,
                sign=sign,
                plt_fig_axis=(fig, col),
                show_colorbar=True,
                outlier_perc=2,
                cmap=cmap,
                use_pyplot=False,
            )
    plt.savefig(f"images/{id}.png")
    plt.show()
# Every explanation method under its display name.
expl_dict = {
    "Gradients": attributions_ig_nt_all,
    "Saliency": attr_saliency,
    "Occlusion 25": attributions_occ_all_25,
    "Occlusion 15": attributions_occ_all_15,
    "Max Aggregate": attr_agg_max,
    "Min Aggregate": attr_agg_min,
    "Avg Aggregate": attr_agg_avg,
}
explanations_three = torch.cat(
    [
        all_stacked,
        attr_saliency.unsqueeze(dim=1),
        attributions_occ_all_15.unsqueeze(dim=1),
        attr_agg_max.unsqueeze(dim=1),
        attr_agg_min.unsqueeze(dim=1),
        attr_agg_avg.unsqueeze(dim=1),
    ],
    dim=1,
)
predicted_names = [imagenet_classes_dict[str(i.item())][1] for i in preds2]
plot_explanations(all_images, expl_dict, predicted_names, method="blended_heat_map")


def predict(x):
    """Return class probabilities for a batch ``x``.

    The softmax runs over dim=1, the dimension the script reduces over when
    it takes ``torch.max(outputs2, 1)`` and the softmax of ``outputs2``.
    Using dim=0 normalised across the batch instead.
    """
    return torch.nn.functional.softmax(model(x), dim=1)


dict_to_matrix(proper_data, expl_dict, predict, tensor_masks)
| Decision Impact Ratio0 | Confidence Impact Ratio Same0 | CIR Max0 | Average Recall0 | Average Precision0 | Decision Impact Ratio1 | Confidence Impact Ratio Same1 | CIR Max1 | Average Recall1 | Average Precision1 | ... | CIR Max8 | Average Recall8 | Average Precision8 | Decision Impact Ratio9 | Confidence Impact Ratio Same9 | CIR Max9 | Average Recall9 | Average Precision9 | F1_score | IOU | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Gradients | 1.0 | 0.292043 | 0.292043 | 1.000000 | 0.239076 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.358053 | 0.000000 |
| Saliency | 1.0 | 0.292043 | 0.292043 | 1.000000 | 0.239318 | 0.8 | 0.130759 | -0.075646 | 0.200246 | 0.477015 | ... | -0.009288 | 0.000763 | 0.428025 | 0.2 | 0.016281 | -0.000817 | 0.000368 | 0.318214 | 0.358287 | 0.004982 |
| Occlusion 25 | 0.8 | 0.126705 | -0.086325 | 0.567629 | 0.260502 | 0.8 | 0.084125 | -0.002501 | 0.266416 | 0.432520 | ... | -0.002693 | 0.000890 | 0.100000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.302462 | 0.023134 |
| Occlusion 15 | 1.0 | 0.243223 | -0.117268 | 0.482177 | 0.238699 | 0.9 | 0.135679 | -0.027287 | 0.143442 | 0.459755 | ... | 0.000890 | 0.000057 | 0.100000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.277534 | 0.004646 |
| Max Aggregate | 1.0 | 0.237470 | -0.106383 | 0.668883 | 0.246700 | 0.8 | 0.157906 | -0.039502 | 0.305328 | 0.435404 | ... | -0.001771 | 0.000947 | 0.200000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.314795 | 0.024842 |
| Min Aggregate | 0.9 | 0.168878 | -0.132847 | 0.380922 | 0.254946 | 0.5 | 0.097367 | 0.023040 | 0.104529 | 0.442264 | ... | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.256129 | 0.002419 |
| Avg Aggregate | 1.0 | 0.185669 | -0.134782 | 0.536309 | 0.250199 | 0.7 | 0.092850 | -0.034366 | 0.192898 | 0.436170 | ... | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.291462 | 0.006486 |
7 rows × 52 columns
# Per-image summary: recall and precision at threshold 0.2, averaged over all
# explanation methods, plus the consistency between the methods.
per_image = pd.DataFrame()
_recalls = torch.stack(
    [Metrics.accordance_recall(expl, tensor_masks, 0.2) for expl in expl_dict.values()]
)
per_image["Average Recall"] = torch.mean(_recalls, dim=0).numpy()
_precisions = torch.stack(
    [Metrics.accordance_precision(expl, tensor_masks, 0.2) for expl in expl_dict.values()]
)
per_image["Average Precision"] = torch.mean(_precisions, dim=0).numpy()
# Stacking along dim=1 means iterating yields, per image, all of its explanations.
_explanations_per_image = torch.stack(list(expl_dict.values()), dim=1)
per_image["Consistency"] = [Metrics.consistency(expls) for expls in _explanations_per_image]
per_image
| Average Recall | Average Precision | Consistency | |
|---|---|---|---|
| 0 | 0.107373 | 0.373676 | 0.015499 |
| 1 | 0.029311 | 0.664992 | 0.025282 |
| 2 | 0.004816 | 0.371717 | 0.026888 |
| 3 | 0.335163 | 0.201642 | 0.018533 |
| 4 | 0.052924 | 0.847267 | 0.027571 |
| 5 | 0.004855 | 0.056777 | 0.023819 |
| 6 | 0.012087 | 0.446683 | 0.047301 |
| 7 | 0.000784 | 0.045918 | 0.033685 |
| 8 | 0.047417 | 0.172541 | 0.022880 |
| 9 | 0.189759 | 0.348346 | 0.020937 |
import pickle

# Save the explanations to images/expl_<id>.pickle and read them straight
# back; `b` is the reloaded copy.
a = expl_dict
_pickle_path = f'images/expl_{id}.pickle'
with open(_pickle_path, 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(_pickle_path, 'rb') as handle:
    b = pickle.load(handle)
b.keys()
dict_keys(['Gradients', 'Saliency', 'Occlusion 25', 'Occlusion 15', 'Max Aggregate', 'Min Aggregate', 'Avg Aggregate'])